import torch.nn


def init_weights(m: torch.nn.Module) -> None:
    """Re-initialize the parameters of a single module in place.

    The signature fits ``model.apply(init_weights)``, which visits every
    submodule.

    - ``torch.nn.Conv2d``: the weight is re-drawn with Xavier (Glorot) normal
      initialization. The bias, if any, is left untouched.
    - ``torch.nn.BatchNorm2d``: the weight is drawn from a normal distribution
      with mean 1.0 and std 0.02, and the bias is set to 0. Layers created with
      ``affine=False`` have no weight/bias and are skipped.
    - Every other module type is left unchanged.

    Args:
        m: The module to (possibly) initialize.
    """
    # Exact type match rather than isinstance(): subclasses such as
    # torch.nn.LazyConv2d hold uninitialized parameters before their first
    # forward pass and must not be touched here.
    module_type = type(m)

    if module_type is torch.nn.Conv2d:
        torch.nn.init.xavier_normal_(m.weight)

    elif module_type is torch.nn.BatchNorm2d:
        # With affine=False these attributes are None; initializing them
        # would raise, so only touch the parameters that actually exist.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
